import subprocess
import os
import sys
import time

# Seconds slept between two rounds of metrics collection.
CHECK_METRICS_PERIOD: int = 5
# Native helper that performs the JVM dynamic attach.
JVM_ATTACH_PATH: str = "/opt/gala-gopher/lib/jvm_attach"
# Sermant agent jar handed to the JVM on attach (the Sermant package is copied under the target's /tmp).
SERMANT_PATH: str = "/tmp/sermant/agent/sermant-agent.jar"
# Agent argument strings appended after "<jar>=" to install / uninstall the agent.
SERMANT_COMMAND_INSTALL: str = "command=INSTALL-AGENT"
SERMANT_COMMAND_UNINSTALL: str = "command=UNINSTALL-AGENT"
# pid with Sermant installed -> number of metrics-file lines already forwarded.
PIDS_DICT: dict = dict()


def print_to_log(msg: str, level="[DEBUG]"):
    """Write one ``<level>[SERMANT]:<msg>`` record to stdout and flush it.

    :param msg: message text (must be a str).
    :param level: bracketed level prefix, ``"[DEBUG]"`` by default.
    """
    record = "".join((level, "[SERMANT]:", msg))
    print(record)
    # flush right away so the record is visible to whoever reads our stdout
    sys.stdout.flush()


def attach_jar_by_c(pid: int, nspid: int, args: str) -> int:
    """Attach the Sermant agent jar to a JVM with the native ``jvm_attach`` helper.

    Runs ``jvm_attach <pid> <nspid> load instrument false <SERMANT_PATH>=<args>``
    and logs every line the helper writes to stdout. No check is made here
    that the target is a Java process; callers are expected to do that.

    :param pid: pid of the target JVM as seen from this process.
    :param nspid: pid of the same JVM inside its own pid namespace.
    :param args: agent argument string appended after ``=``.
    :return: the helper's exit status (0 on success), or -1 if the helper
        could not be started at all.
    """
    pid = str(pid)
    nspid = str(nspid)
    agent_arg = SERMANT_PATH + "=" + args
    # Pass an argv list instead of a shell string: the agent argument reaches
    # jvm_attach verbatim, with no shell quoting or metacharacter expansion.
    command = [JVM_ATTACH_PATH, pid, nspid, "load", "instrument", "false", agent_arg]
    print_to_log("[pid: %s ] the command is [%s]" % (pid, " ".join(command)))
    try:
        process = subprocess.Popen(command, stdout=subprocess.PIPE, bufsize=1, universal_newlines=True)
    except OSError as e:
        # Without a shell, a missing/non-executable helper raises here instead
        # of producing a non-zero exit status; report it as a failed attach.
        print_to_log("[pid: %s ] failed to start jvm_attach: %s" % (pid, str(e)), "[ERROR]")
        return -1

    # The context manager closes the stdout pipe once we are done with it.
    with process:
        for line in process.stdout:
            # drop the trailing newline so each log record stays on one line
            print_to_log("[pid: %s ] the out of jvm_attach is: %s" % (pid, line.rstrip("\n")))
        return_code = process.wait()
    print_to_log("[pid: %s ] return code is: %d" % (pid, return_code))
    return return_code


def get_nspid(pid: int) -> "tuple[bool, int]":
    """Return the pid of ``pid`` inside its innermost pid namespace.

    Parses the ``NSpid:`` line of ``/proc/<pid>/status``. That line lists the
    pid once per pid namespace the process belongs to, outermost first, so
    the last field is the pid the process sees for itself. For a process in
    this script's own namespace, it is simply ``pid``.

    :param pid: pid as seen from this script's namespace.
    :return: ``(True, nspid)`` on success, ``(False, 0)`` when the status file
        cannot be read or contains no usable ``NSpid:`` line.
    """
    nspid_path = '/proc/' + str(pid) + '/status'
    # Open directly rather than isfile()-then-open: the process can exit at
    # any moment, and the check-then-open race would raise to the caller.
    try:
        with open(nspid_path, 'r') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print_to_log("[pid: %s ] file '%s' is not exist!" % (str(pid), nspid_path), "[ERROR]")
        return False, 0
    except OSError as e:
        print_to_log("[pid: %s ] failed to read '%s': %s" % (str(pid), nspid_path, str(e)), "[ERROR]")
        return False, 0

    for line in lines:
        fields = line.split()
        if len(fields) >= 2 and fields[0] == 'NSpid:':
            try:
                # innermost namespace is the last column (host: only column)
                return True, int(fields[-1])
            except ValueError:
                break
    print_to_log("[pid: %s ] nspid not found." % (str(pid)), "[ERROR]")
    return False, 0


def attach_sermant(pid: int, args: str) -> bool:
    """Attach the Sermant agent to ``pid`` with agent arguments ``args``.

    The attach is only attempted when ``/proc/<pid>/comm`` reads ``java`` and
    the namespace pid of the process can be determined.

    :param pid: target process id.
    :param args: agent argument string, e.g. ``SERMANT_COMMAND_INSTALL`` or
        ``SERMANT_COMMAND_UNINSTALL``.
    :return: True only if the attach helper exited with status 0.
    """
    print_to_log("enter attach_sermant")
    comm_path = '/proc/' + str(pid) + '/comm'
    # Open directly rather than isfile()-then-open: the process may exit in
    # between, and that race would raise out of install/uninstall.
    try:
        with open(comm_path, 'r') as f:
            comm = f.readline().rstrip('\n')
    except FileNotFoundError:
        print_to_log("[pid: %s ] file '%s' is not exist!" % (str(pid), comm_path), "[ERROR]")
        return False
    except OSError as e:
        print_to_log("[pid: %s ] failed to read '%s': %s" % (str(pid), comm_path, str(e)), "[ERROR]")
        return False

    print_to_log("enter attach_sermant get comm")
    # only JVMs can be attached to
    if comm != "java":
        print_to_log("[pid: %s ] process is not a java process. the comm is %s" % (str(pid), comm), "[ERROR]")
        return False

    get_nspid_result, nspid = get_nspid(pid)
    if not get_nspid_result:
        # nspid unknown; get_nspid has already logged the reason
        return False
    return attach_jar_by_c(pid=pid, nspid=nspid, args=args) == 0


def modify_tgid(line: str, pid: int) -> "tuple[bool, str]":
    """Replace the tgid field of one ``|``-separated metrics line with ``pid``.

    The line is split on ``|`` and field index 2 is overwritten. For a line
    that starts with ``|`` (``|table|tgid|...``) index 0 is empty, so index 2
    is the first field after the table name. NOTE(review): presumably this
    field holds the JVM's namespace-local pid, which is swapped for the host
    pid here; confirm against the Sermant metrics format.

    :param line: one metrics line; any trailing newline is preserved.
    :param pid: value written into the tgid field.
    :return: ``(True, rewritten_line)``, or ``(False, "")`` when the line has
        fewer than three ``|``-separated fields.
    """
    parts = line.split('|')
    if len(parts) < 3:
        print_to_log("Failed to change the metrics tgid", "[ERROR]")
        return False, ""
    parts[2] = str(pid)
    return True, '|'.join(parts)


def read_file(pid: int, line_num: int) -> "tuple[bool, int, int]":
    """Print the metrics lines Sermant appended to its data file since the last read.

    The file is ``/proc/<pid>/root/tmp/java-data-<nspid>/sermant-metrics.txt``.
    If the namespace pid cannot be determined, ``<pid>`` is used instead of
    ``<nspid>``. Every line from index ``line_num`` onward gets its tgid field
    rewritten to ``pid`` and is printed to stdout, one metric per line.

    :param pid: host pid of the JVM.
    :param line_num: number of lines of the file already consumed.
    :return: ``(read_ok, record_line, write_line)``:
        - ``record_line`` is the offset to pass next time. It is the file's
          current line count, or ``line_num`` unchanged if the file could not
          be read.
        - ``write_line`` is the number of new lines processed.
    """
    start = int(line_num)
    get_nspid_result, nspid = get_nspid(pid)
    data_id = nspid if get_nspid_result else pid
    result_path = '/proc/' + str(pid) + "/root/tmp/java-data-" + str(data_id) + "/sermant-metrics.txt"

    # Open directly: the target process (and so its /proc root) can vanish at
    # any moment, which an isfile() pre-check cannot protect against.
    try:
        with open(result_path, 'r') as f:
            lines = f.readlines()
    except FileNotFoundError:
        print_to_log("[pid: %s ] file '%s' is not exist!" % (str(pid), result_path))
        return False, start, 0
    except OSError as e:
        print_to_log("[pid: %s ] failed to read '%s': %s" % (str(pid), result_path, str(e)), "[ERROR]")
        return False, start, 0

    write_line = 0
    for line in lines[start:]:
        modify_result, modify_line = modify_tgid(line, pid)
        if modify_result:
            # the line already carries its '\n'; avoid print() adding a blank line
            print(modify_line.rstrip('\n'))
        write_line = write_line + 1

    # Resume after the last line seen. Using len(lines) keeps the offset at 0
    # for an empty file (the old "last index + 1" advanced it by one on every
    # empty read, later skipping real lines). It also follows the file when it
    # shrank.
    return True, len(lines), write_line


def clear_file(pid: int):
    """Truncate the Sermant metrics file of ``pid`` to zero length, if it exists.

    Uses the same path as the metrics reader:
    ``/proc/<pid>/root/tmp/java-data-<nspid or pid>/sermant-metrics.txt``.
    A missing file is silently ignored. Other I/O errors are logged rather
    than raised, so callers tearing down several pids are not interrupted.

    :param pid: host pid of the JVM.
    """
    get_nspid_result, nspid = get_nspid(pid)
    data_id = nspid if get_nspid_result else pid
    result_path = '/proc/' + str(pid) + "/root/tmp/java-data-" + str(data_id) + "/sermant-metrics.txt"
    try:
        # os.truncate never creates the file, so a missing file stays missing,
        # and there is no exists()-then-open race with the process exiting.
        os.truncate(result_path, 0)
    except FileNotFoundError:
        return
    except OSError as e:
        print_to_log("[pid: %s ] failed to clear file '%s': %s" % (str(pid), result_path, str(e)), "[ERROR]")


def uninstall_all():
    """Flush, clear and uninstall Sermant for every tracked pid.

    For each pid in ``PIDS_DICT``:
    - forward its remaining metrics;
    - truncate its metrics file;
    - detach the agent.

    A failure on one pid is logged and does not stop teardown of the others.
    """
    # Iterate a snapshot: a successful uninstall removes the pid from PIDS_DICT.
    for pid, line_num in list(PIDS_DICT.items()):
        try:
            # read the metrics before uninstall sermant
            read_result, _, _ = read_file(pid, line_num)
            print_to_log("[pid: %s ] last read metrics, result is '%s'!" % (str(pid), str(read_result)))
            # clear the sermant-metrics.txt
            clear_file(pid)
        except Exception as e:
            # still try to detach the agent even if the final flush failed
            print_to_log("[pid: %s ] failed to flush metrics: %s" % (str(pid), str(e)), "[ERROR]")
        try:
            uninstall_result = uninstall_sermant(pid)
        except Exception as e:
            print_to_log("[pid: %s ] uninstall sermant raised: %s" % (str(pid), str(e)), "[ERROR]")
            continue
        print_to_log("[pid: %s ] uninstall sermant, result is '%s'!" % (str(pid), str(uninstall_result)))


def check_pid_file():
    """Sleep one period, then forward new metrics of every tracked pid.

    Each pid's consumed-line offset in ``PIDS_DICT`` is advanced to what
    ``read_file`` reports. The total number of lines processed is logged.
    """
    time.sleep(CHECK_METRICS_PERIOD)
    write_line = 0
    # Iterate a snapshot: if collection runs on a different thread from
    # install/uninstall, PIDS_DICT may gain or lose keys while we loop, which
    # would raise "dictionary changed size during iteration".
    for pid, num in list(PIDS_DICT.items()):
        _, record_num, write_line_a = read_file(pid, num)
        # don't re-add a pid that was uninstalled while its file was being read
        if pid in PIDS_DICT:
            PIDS_DICT[pid] = record_num
        write_line = write_line + write_line_a
    print_to_log("metrics write line: " + str(write_line))


def check_metrics():
    """Collect metrics forever, one ``check_pid_file`` round per period.

    Never returns; presumably meant to be run as the body of a background
    thread. An exception raised by one round is logged at ERROR level and
    the loop moves on to the next round.
    """
    while 1:
        try:
            check_pid_file()
        except Exception as exc:
            # one bad round must not stop the collector
            message = "read file exception: %s" % exc
            print_to_log(message, "[ERROR]")


def check_and_copy_sermant(pid: int, path: str) -> bool:
    """Ensure the Sermant package is present under the target process's /tmp.

    If ``/proc/<pid>/root/tmp/sermant/`` does not exist, ``path`` is copied
    recursively (``cp -r``) into ``/proc/<pid>/root/tmp/``. That is the /tmp
    of the process's root filesystem, e.g. inside its container.

    :param pid: target process id.
    :param path: local Sermant package directory to copy.
    :return: True if the package was already there or the copy succeeded.
    """
    target_dir = "/proc/" + str(pid) + "/root/tmp/"
    if os.path.exists(target_dir + "sermant/"):
        return True
    # copy the sermant project package into the container.
    # argv list, no shell: `path` is passed verbatim and cannot inject commands.
    cp_command = ["cp", "-r", path, target_dir]
    try:
        result = subprocess.run(cp_command, stdout=subprocess.PIPE, stderr=subprocess.PIPE, text=True)
    except (OSError, ValueError, TypeError) as e:
        # cp missing, bad path value, etc.; narrower than the old bare except,
        # which also swallowed KeyboardInterrupt/SystemExit
        print_to_log("copy sermant failed: " + str(e), "[ERROR]")
        return False
    print_to_log("copy result code:" + str(result.returncode))
    if result.returncode != 0:
        print_to_log("copy stderr: " + result.stderr.strip(), "[ERROR]")
    return result.returncode == 0


def install_sermant(pid: int) -> bool:
    """Attach Sermant in install mode; on success start tracking ``pid``.

    :return: the attach result (True on success).
    """
    installed = attach_sermant(pid, args=SERMANT_COMMAND_INSTALL)
    if not installed:
        return installed
    # track the pid so its metrics file gets polled
    record_install_pid(pid)
    return installed


def uninstall_sermant(pid: int) -> bool:
    """Attach Sermant in uninstall mode; on success stop tracking ``pid``.

    :return: the attach result (True on success).
    """
    removed = attach_sermant(pid, args=SERMANT_COMMAND_UNINSTALL)
    if not removed:
        return removed
    delete_install_success(pid)
    return removed


def record_install_pid(pid: int):
    """Start tracking ``pid`` with a metrics read offset of 0 lines.

    An existing entry for ``pid`` is reset to 0.
    """
    # in-place mutation of the shared dict; no `global` rebinding required
    PIDS_DICT.update({pid: 0})


def delete_install_success(pid: int):
    """Stop tracking ``pid``.

    Removing a pid that is not tracked is a no-op. It used to raise KeyError,
    e.g. when uninstalling a pid that was never recorded, or one already
    dropped.
    """
    PIDS_DICT.pop(pid, None)


def check_install_pid(pid: int) -> bool:
    """Return True if ``pid`` is NOT yet tracked as having Sermant installed.

    Despite the name, a True result means "not installed / needs install";
    False means the pid is already recorded in ``PIDS_DICT``.
    """
    not_installed = PIDS_DICT.get(pid) is None
    return not_installed


def check_and_uninstall_pid(pids: list):
    """Uninstall Sermant from every tracked pid that is not listed in ``pids``.

    :param pids: pids that should keep Sermant attached. Presumably these are
        the java processes currently selected for monitoring; confirm with
        the caller.
    """
    print_to_log("check that the cached pid and uninstall sermant", "[INFO]")
    keep = set(pids)  # O(1) membership instead of a list scan per tracked pid
    for key in list(PIDS_DICT):
        if key in keep:
            continue
        result = uninstall_sermant(key)
        print_to_log("uninstall Sermant result: " + str(result))
        # A process that has exited cannot be detached from, so its uninstall
        # keeps failing and the pid would stay cached (and polled) forever.
        # Drop it once its /proc entry is gone.
        if not result and not os.path.exists("/proc/" + str(key)):
            PIDS_DICT.pop(key, None)
            print_to_log("[pid: %s ] process no longer exists, drop it from cache" % str(key), "[INFO]")


if __name__ == "__main__":
    # Manual test entry point: run the metrics collection loop in the foreground.
    check_metrics()
